
Amortized Bayesian Inference with PyTorch

pytorch
variational auto-encoders
amortized bayesian inference
Published

July 25, 2025

Heuristics in Latent Space
The cost of generating new sample data can be prohibitive. There is a secondary, different cost attached to the ‘construction’ of novel data. Principal Components Analysis can be seen as a technique for optimally reconstructing a complex multivariate data set from a lower-dimensional compressed space. Variational auto-encoders achieve still more flexible reconstructions in non-linear cases. Drawing a new sample from the posterior predictive distribution of a Bayesian model similarly gives us insight into the variability of realised data. Both methods assume a latent model of the data-generating process that leverages a compressed representation of the data. These are different heuristics with different consequences for how we understand variability in the world. Amortized Bayesian inference seeks to unite the two heuristics.

Reconstruction Error

It’s natural to seek short cuts.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pymc as pm

Job Satisfaction Data

import numpy as np

# Standard deviations
stds = np.array([0.939, 1.017, 0.937, 0.562, 0.760, 0.524, 
                 0.585, 0.609, 0.731, 0.711, 1.124, 1.001])

n = len(stds)

# Lower triangular correlation values as a flat list
corr_values = [
    1.000,
    .668, 1.000,
    .635, .599, 1.000,
    .263, .261, .164, 1.000,
    .290, .315, .247, .486, 1.000,
    .207, .245, .231, .251, .449, 1.000,
   -.206, -.182, -.195, -.309, -.266, -.142, 1.000,
   -.280, -.241, -.238, -.344, -.305, -.230,  .753, 1.000,
   -.258, -.244, -.185, -.255, -.255, -.215,  .554,  .587, 1.000,
    .080,  .096,  .094, -.017,  .151,  .141, -.074, -.111,  .016, 1.000,
    .061,  .028, -.035, -.058, -.051, -.003, -.040, -.040, -.018,  .284, 1.000,
    .113,  .174,  .059,  .063,  .138,  .044, -.119, -.073, -.084,  .563,  .379, 1.000
]

# Fill correlation matrix
corr_matrix = np.zeros((n, n))
idx = 0
for i in range(n):
    for j in range(i+1):
        corr_matrix[i, j] = corr_values[idx]
        corr_matrix[j, i] = corr_values[idx]
        idx += 1

# Covariance matrix: Sigma = D * R * D
cov_matrix = np.outer(stds, stds) * corr_matrix
#cov_matrix_test = np.dot(np.dot(np.diag(stds), corr_matrix), np.diag(stds))
columns=["JW1","JW2","JW3", "UF1","UF2","FOR", "DA1","DA2","DA3", "EBA","ST","MI"]
corr_df = pd.DataFrame(corr_matrix, columns=columns)

cov_df = pd.DataFrame(cov_matrix, columns=columns)
cov_df
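Since the covariance is built as Σ = D·R·D, it is worth sanity-checking the construction before sampling from it. A toy three-variable sketch of the same pattern (with made-up numbers, not the job-satisfaction values): the outer-product form should match the explicit diagonal-matrix form, the diagonal should recover the variances, and a Cholesky factorization should succeed if the matrix is valid.

```python
import numpy as np

# Toy 3-variable example of the same construction: Sigma = D * R * D
stds = np.array([0.9, 1.1, 0.5])
corr = np.array([
    [1.0, 0.6, 0.2],
    [0.6, 1.0, 0.3],
    [0.2, 0.3, 1.0],
])
cov = np.outer(stds, stds) * corr

# Equivalent matrix form with explicit diagonal matrices
cov_check = np.diag(stds) @ corr @ np.diag(stds)
assert np.allclose(cov, cov_check)

# Symmetric, diagonal equals the variances, positive definite
assert np.allclose(cov, cov.T)
assert np.allclose(np.diag(cov), stds**2)
np.linalg.cholesky(cov)  # raises LinAlgError if not positive definite
```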

def make_sample(cov_matrix, size, columns):
    # Draw `size` rows from a zero-mean multivariate normal with the given covariance
    sample_df = pd.DataFrame(
        np.random.multivariate_normal(np.zeros(len(columns)), cov_matrix, size=size),
        columns=columns
    )
    return sample_df

sample_df = make_sample(cov_matrix, 263, columns)
sample_df.head()
JW1 JW2 JW3 UF1 UF2 FOR DA1 DA2 DA3 EBA ST MI
0 -1.241607 -0.218270 -1.101984 -0.092699 -0.799012 0.174001 -0.370272 -0.551770 -0.343902 1.277332 0.947278 0.495876
1 0.635936 -0.614378 0.382529 1.100494 1.126426 1.081677 0.292383 -0.173557 -0.208669 1.073014 0.355770 0.256730
2 -0.551499 -2.280220 -1.304824 -0.506253 -0.047927 -0.113652 1.094697 0.889694 0.986469 -0.734002 -0.599533 -0.168585
3 1.054921 1.173777 -0.138497 -0.186385 -1.286099 -0.618353 1.360622 1.045441 1.339666 -0.369791 0.451579 -1.216746
4 -0.886769 0.124020 -0.337509 0.162294 0.222886 0.051055 -0.122539 -0.880796 -0.976451 1.344133 -0.079103 0.505976
data = sample_df.corr()

def plot_heatmap(data, title="Correlation Matrix", vmin=-.2, vmax=.2, ax=None, figsize=(10, 6), colorbar=True):
    data_matrix = data.values
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(data_matrix, cmap='viridis', vmin=vmin, vmax=vmax)

    # Annotate each cell with its value, picking a contrasting text colour
    for i in range(data_matrix.shape[0]):
        for j in range(data_matrix.shape[1]):
            ax.text(
                j, i,                             # x, y coordinates
                f"{data_matrix[i, j]:.2f}",       # text to display
                ha="center", va="center",         # center alignment
                color="white" if data_matrix[i, j] < 0.5 else "black"
            )

    ax.set_title(title)
    # Set tick positions before labels to avoid the FixedLocator warning
    ax.set_xticks(np.arange(data.shape[1]))
    ax.set_xticklabels(data.columns)
    ax.set_yticks(np.arange(data.shape[0]))
    ax.set_yticklabels(data.index)
    if colorbar:
        plt.colorbar(im)

plot_heatmap(data)

X = make_sample(cov_matrix, 100, columns=columns)
U, S, VT = np.linalg.svd(X, full_matrices=False)
ranks = [2, 5, 12]
reconstructions = []
for k in ranks:
    X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
    reconstructions.append(X_k)

# Plot original and reconstructed matrices
fig, axes = plt.subplots(1, len(ranks) + 1, figsize=(10,15))
axes[0].imshow(X, cmap='viridis')
axes[0].set_title("Original")
axes[0].axis("off")

for ax, k, X_k in zip(axes[1:], ranks, reconstructions):
    ax.imshow(X_k, cmap='viridis')
    ax.set_title(f"Rank {k}")
    ax.axis("off")

plt.suptitle("Reconstruction of Data Using SVD \n various truncation options",fontsize=12, x=.5, y=1.01)
plt.tight_layout()
plt.show()
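How good is each truncation? By the Eckart–Young theorem the rank-k SVD truncation is the best rank-k approximation in the Frobenius norm, and the squared relative error is exactly the share of squared singular values we discard. A self-contained sketch on random data (generic shapes for illustration, not the job-satisfaction sample):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 12))

U, S, VT = np.linalg.svd(X, full_matrices=False)

def rel_error(k):
    # Relative Frobenius error of the rank-k truncated reconstruction
    X_k = U[:, :k] @ np.diag(S[:k]) @ VT[:k, :]
    return np.linalg.norm(X - X_k) / np.linalg.norm(X)

# Error shrinks as rank grows; the full-rank reconstruction is exact
assert rel_error(2) > rel_error(5) > rel_error(12)
assert rel_error(12) < 1e-10

# Squared relative error equals the discarded share of squared singular values
assert np.isclose(rel_error(5)**2, (S[5:]**2).sum() / (S**2).sum())
```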

Variational Auto-Encoders

scaler = StandardScaler()
X = scaler.fit_transform(sample_df)
X_train, X_test = train_test_split(X, test_size=0.2, random_state=42)

X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
class NumericVAE(nn.Module):
    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        
        # ---------- ENCODER ----------
        # First layer: compress input features into a hidden representation
        self.fc1 = nn.Linear(n_features, hidden_dim)
        
        # Latent space parameters (q(z|x)): mean and log-variance
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # μ(x)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)   # log(σ^2(x))
        
        # ---------- DECODER ----------
        # First layer: map latent variable z back into hidden representation
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        
        # Output distribution parameters for reconstruction p(x|z)
        # For numeric data, we predict both mean and log-variance per feature
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)        # μ_x(z)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)    # log(σ^2_x(z))

    # ENCODER forward pass: input x -> latent mean, log-variance
    def encode(self, x):
        h = F.relu(self.fc1(x))       # Hidden layer with ReLU
        mu = self.fc_mu(h)            # Latent mean vector
        logvar = self.fc_logvar(h)    # Latent log-variance vector
        return mu, logvar

    # Reparameterization trick: sample z = μ + σ * ε  (ε ~ N(0,1))
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)   # σ = exp(0.5 * logvar)
        eps = torch.randn_like(std)     # ε ~ N(0, I)
        return mu + eps * std           # z = μ + σ * ε

    # DECODER forward pass: latent z -> reconstructed mean, log-variance
    def decode(self, z):
        h = F.relu(self.fc2(z))             # Hidden layer with ReLU
        recon_mu = self.fc_out_mu(h)        # Mean of reconstructed features
        recon_logvar = self.fc_out_logvar(h)# Log-variance of reconstructed features
        return recon_mu, recon_logvar

    # Full forward pass: input x -> reconstructed (mean, logvar), latent params
    def forward(self, x):
        mu, logvar = self.encode(x)            # q(z|x)
        z = self.reparameterize(mu, logvar)    # Sample z from q(z|x)
        recon_mu, recon_logvar = self.decode(z)# p(x|z)
        return (recon_mu, recon_logvar), mu, logvar

    # Sample new synthetic data: z ~ N(0,I), decode to x
    def generate(self, n_samples=100):
        self.eval()
        with torch.no_grad():
            # Sample z from standard normal prior
            z = torch.randn(n_samples, self.fc_mu.out_features)
            
            # Decode to get reconstruction distribution parameters
            cont_mu, cont_logvar = self.decode(z)
            
            # Sample from reconstructed Gaussian: μ_x + σ_x * ε
            return cont_mu + torch.exp(0.5 * cont_logvar) * torch.randn_like(cont_mu)
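The reparameterization trick in `reparameterize` expresses a sample from N(μ, σ²) as a deterministic transform of standard normal noise, which is what lets gradients flow through the sampling step. A quick NumPy check (illustrative values only) that the transform produces the right distribution:

```python
import numpy as np

rng = np.random.default_rng(7)
mu, logvar = 1.5, -0.8
std = np.exp(0.5 * logvar)  # sigma = exp(0.5 * logvar), as in the model

# z = mu + sigma * eps, with eps ~ N(0, 1)
eps = rng.normal(size=200_000)
z = mu + std * eps

# Empirical mean and std of z match the intended Gaussian
assert abs(z.mean() - mu) < 0.01
assert abs(z.std() - std) < 0.01
```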
def vae_loss(x, mu_out, logvar_out, mu, logvar):
    # Gaussian log-likelihood of x under p(x|z), dropping the additive constant
    recon = -0.5 * torch.sum(logvar_out + (x - mu_out)**2 / torch.exp(logvar_out))
    # Closed-form KL divergence between q(z|x) and the standard normal prior
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Negative ELBO: reconstruction NLL plus KL penalty
    return -(recon - kl)
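The `kl` term above is the closed-form KL divergence between the diagonal Gaussian q(z|x) = N(μ, σ²) and a standard normal prior: per dimension, 0.5(σ² + μ² − 1 − log σ²). We can check this against a Monte Carlo estimate of E_q[log q(z) − log p(z)]; a NumPy sketch with a fixed seed and illustrative values:

```python
import numpy as np

rng = np.random.default_rng(42)
mu, logvar = 0.7, -0.4
var = np.exp(logvar)

# Closed form used in vae_loss (single latent dimension)
kl_closed = -0.5 * (1 + logvar - mu**2 - var)

# Monte Carlo estimate: E_q[log q(z) - log p(z)] with z ~ N(mu, var)
z = rng.normal(mu, np.sqrt(var), size=500_000)
log_q = -0.5 * (np.log(2 * np.pi) + logvar + (z - mu)**2 / var)
log_p = -0.5 * (np.log(2 * np.pi) + z**2)
kl_mc = np.mean(log_q - log_p)

assert abs(kl_closed - kl_mc) < 0.01
```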
vae = NumericVAE(n_features=X_train.shape[1])
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

train_loader = torch.utils.data.DataLoader(X_train, batch_size=32, shuffle=True)

for epoch in range(3000):
    vae.train()
    total_loss = 0
    for x in train_loader:
        (mu_out, logvar_out), mu, logvar = vae(x)
        loss = vae_loss(x, mu_out, logvar_out, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader.dataset):.4f}")
vae.eval()
synthetic_data = vae.generate(n_samples=263)
synthetic_data = scaler.inverse_transform(synthetic_data.detach().numpy())

synthetic_df = pd.DataFrame(synthetic_data, columns=columns)
synthetic_df.head()
JW1 JW2 JW3 UF1 UF2 FOR DA1 DA2 DA3 EBA ST MI
0 -0.228035 -0.309410 0.265740 -0.448498 -0.915894 0.969787 0.307723 0.328543 -0.070029 -0.436048 -0.536932 -1.274912
1 -1.148540 -1.139884 -0.310325 -0.104545 -0.974271 -0.017864 0.179623 -0.316703 -0.202298 0.928373 0.981866 -0.187083
2 0.384822 1.511461 0.550769 -0.092369 0.339347 0.163415 0.479706 0.283485 0.816934 -0.285522 -1.234980 -0.156891
3 -0.655432 -0.915105 -0.833992 -1.062292 -2.098374 -0.640139 0.843921 0.602708 0.548233 -1.366834 -0.263984 -1.053080
4 -0.534100 0.136500 -0.637269 1.292988 -0.329050 -0.212850 -0.587650 -0.555491 0.407691 -0.391056 -0.810426 0.731710
sample_df.corr() - synthetic_df.corr()
JW1 JW2 JW3 UF1 UF2 FOR DA1 DA2 DA3 EBA ST MI
JW1 0.000000 0.028921 0.005057 0.063538 -0.004431 -0.017299 -0.106423 -0.052280 -0.013757 -0.095653 -0.045535 -0.027609
JW2 0.028921 0.000000 0.031469 -0.005387 -0.127076 0.016175 -0.060252 -0.009769 0.036346 -0.165151 -0.020811 -0.115360
JW3 0.005057 0.031469 0.000000 -0.065580 -0.078823 -0.031757 -0.094614 -0.053747 -0.018999 -0.169537 -0.032128 -0.165671
UF1 0.063538 -0.005387 -0.065580 0.000000 -0.029049 -0.067482 0.035595 0.052385 -0.030319 -0.014305 0.160711 0.122708
UF2 -0.004431 -0.127076 -0.078823 -0.029049 0.000000 0.047472 0.048636 0.050439 0.055406 0.022550 0.134706 -0.044640
FOR -0.017299 0.016175 -0.031757 -0.067482 0.047472 0.000000 -0.034734 0.023896 -0.127020 -0.173953 0.019393 -0.030765
DA1 -0.106423 -0.060252 -0.094614 0.035595 0.048636 -0.034734 0.000000 -0.013858 0.052531 -0.041426 -0.193282 -0.072918
DA2 -0.052280 -0.009769 -0.053747 0.052385 0.050439 0.023896 -0.013858 0.000000 -0.008346 -0.003036 -0.177174 -0.080805
DA3 -0.013757 0.036346 -0.018999 -0.030319 0.055406 -0.127020 0.052531 -0.008346 0.000000 0.038825 -0.020653 -0.024038
EBA -0.095653 -0.165151 -0.169537 -0.014305 0.022550 -0.173953 -0.041426 -0.003036 0.038825 0.000000 -0.024648 0.011386
ST -0.045535 -0.020811 -0.032128 0.160711 0.134706 0.019393 -0.193282 -0.177174 -0.020653 -0.024648 0.000000 0.019302
MI -0.027609 -0.115360 -0.165671 0.122708 -0.044640 -0.030765 -0.072918 -0.080805 -0.024038 0.011386 0.019302 0.000000
recons = []
n_boot = 1000
resid_array = np.zeros((n_boot, len(sample_df.columns), len(sample_df.columns)))
for i in range(n_boot):
    recon_data = vae.generate(n_samples=len(sample_df)).numpy()
    reconstructed_df = pd.DataFrame(recon_data, columns=sample_df.columns)
    resid = sample_df.corr() - reconstructed_df.corr()
    resid_array[i] = resid.values
    recons.append(reconstructed_df)

avg_resid = resid_array.mean(axis=0)
bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df.columns, index=sample_df.columns)

plot_heatmap(bootstrapped_resids, title="""Expected Residuals \n Under Bootstrapped Reconstructions""")

Missing Data

sample_df_missing = sample_df.copy()

# Randomly mask ~3% of the entries
mask_remove = np.random.rand(*sample_df_missing.shape) < 0.03

# Set those elements to NaN
sample_df_missing[mask_remove] = np.nan
sample_df_missing.head()
JW1 JW2 JW3 UF1 UF2 FOR DA1 DA2 DA3 EBA ST MI
0 -1.241607 -0.218270 -1.101984 -0.092699 -0.799012 0.174001 -0.370272 -0.551770 -0.343902 1.277332 0.947278 0.495876
1 0.635936 -0.614378 0.382529 1.100494 1.126426 1.081677 0.292383 -0.173557 -0.208669 1.073014 0.355770 0.256730
2 -0.551499 -2.280220 -1.304824 -0.506253 -0.047927 -0.113652 1.094697 0.889694 0.986469 -0.734002 -0.599533 -0.168585
3 1.054921 1.173777 -0.138497 -0.186385 -1.286099 -0.618353 1.360622 1.045441 1.339666 -0.369791 0.451579 -1.216746
4 -0.886769 0.124020 -0.337509 0.162294 0.222886 0.051055 -0.122539 NaN -0.976451 1.344133 -0.079103 0.505976

class MissingDataDataset(Dataset):
    def __init__(self, x, mask):
        # x and mask are tensors of same shape
        self.x = x
        self.mask = mask
        
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, idx):
        return self.x[idx], self.mask[idx]

class NumericVAE_missing(nn.Module):
    def __init__(self, n_features, hidden_dim=64, latent_dim=8):
        super().__init__()
        self.n_features = n_features

        # ---------- Learnable Imputation ----------
        # One learnable parameter per feature for missing values
        self.missing_embeddings = nn.Parameter(torch.zeros(n_features))

        # ---------- ENCODER ----------
        self.fc1_x = nn.Linear(n_features, hidden_dim)

        # Stronger mask encoder: 2-layer MLP
        self.fc1_mask = nn.Sequential(
            nn.Linear(n_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Combine feature and mask embeddings
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # ---------- DECODER ----------
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc_out_mu = nn.Linear(hidden_dim, n_features)
        self.fc_out_logvar = nn.Linear(hidden_dim, n_features)

    def encode(self, x, mask):
        # Impute missing values with learnable per-feature parameters.
        # The inputs arrive with NaNs already replaced by 0, so we key off the
        # mask (1 = observed, 0 = missing) rather than torch.isnan.
        x_filled = torch.where(
            mask.bool(),
            x,
            self.missing_embeddings.expand_as(x)
        )

        # Encode features and mask separately
        h_x = F.relu(self.fc1_x(x_filled))
        h_mask = self.fc1_mask(mask)

        # Combine embeddings
        h = h_x + h_mask

        # Latent space
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc2(z))
        recon_mu = self.fc_out_mu(h)
        recon_logvar = self.fc_out_logvar(h)
        return recon_mu, recon_logvar

    def forward(self, x, mask):
        mu, logvar = self.encode(x, mask)
        z = self.reparameterize(mu, logvar)
        recon_mu, recon_logvar = self.decode(z)
        return (recon_mu, recon_logvar), mu, logvar

    def generate(self, n_samples=100):
        self.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, self.fc_mu.out_features)
            recon_mu, recon_logvar = self.decode(z)
            return recon_mu + torch.exp(0.5 * recon_logvar) * torch.randn_like(recon_mu)
def vae_loss_with_missing(recon_mu, recon_logvar, x, mu, logvar, mask):
    recon_var = torch.exp(recon_logvar)
    
    # Only penalize reconstruction for observed data
    recon_loss = 0.5 * torch.sum(
        mask * (recon_logvar + (x - recon_mu)**2 / recon_var)
    )
    
    # KL divergence remains the same
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss, kl_div


def beta_annealing(epoch, max_beta=1.0, anneal_epochs=100):
    # Linearly ramp the KL weight from 0 up to max_beta over anneal_epochs
    return min(max_beta, max_beta * epoch / anneal_epochs)
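The key property of `vae_loss_with_missing` is that multiplying by the mask zeroes out the reconstruction penalty at unobserved entries, so whatever placeholder fills a NaN cell cannot influence the gradient. A NumPy sketch of that masking behaviour (mirroring the formula above, not the PyTorch code itself), with made-up values:

```python
import numpy as np

def masked_recon_loss(x, recon_mu, recon_logvar, mask):
    # Per-element Gaussian NLL (constants dropped), summed over observed cells only
    var = np.exp(recon_logvar)
    return 0.5 * np.sum(mask * (recon_logvar + (x - recon_mu)**2 / var))

x = np.array([[1.0, 2.0, 3.0]])
recon_mu = np.array([[0.9, 2.1, 0.0]])      # third prediction is wildly off
recon_logvar = np.zeros((1, 3))
mask_all = np.ones((1, 3))
mask_missing = np.array([[1.0, 1.0, 0.0]])  # third entry unobserved

full = masked_recon_loss(x, recon_mu, recon_logvar, mask_all)
masked = masked_recon_loss(x, recon_mu, recon_logvar, mask_missing)

# The bad third-column prediction is penalized only when observed
assert masked < full

# Changing the unobserved cell's value leaves the masked loss untouched
x_alt = x.copy()
x_alt[0, 2] = 99.0
assert np.isclose(masked, masked_recon_loss(x_alt, recon_mu, recon_logvar, mask_missing))
```

Paired with `beta_annealing`, the KL term is phased in gradually, which helps the encoder learn a useful representation before the prior regularization bites.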
# Prepare data: impute missing with 0 for now and create mask
mask = ~sample_df_missing.isna()
mask_tensor = torch.tensor(mask.values, dtype=torch.float32)
x_tensor = torch.tensor(sample_df_missing.fillna(0).values, dtype=torch.float32)

batch_size = 32  # or any batch size you prefer
dataset = MissingDataDataset(x_tensor, mask_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Create model & optimizer
model = NumericVAE_missing(n_features=x_tensor.shape[1])
optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 3000
for epoch in range(n_epochs):
    beta = beta_annealing(epoch, max_beta=1.0, anneal_epochs=50)
    model.train()
    
    total_loss = 0
    for x_batch, mask_batch in data_loader:
        optimizer.zero_grad()
        (recon_mu, recon_logvar), mu, logvar = model(x_batch, mask_batch)
        recon_loss, kl_loss = vae_loss_with_missing(recon_mu, recon_logvar, x_batch, mu, logvar, mask_batch)
        loss = recon_loss + beta * kl_loss
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * x_batch.size(0)  # sum over batch

    avg_loss = total_loss / len(dataset)  # average loss per sample
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{n_epochs} - Avg Loss: {avg_loss:.4f}")
recons = []
n_boot = 500
resid_array = np.zeros((n_boot, len(sample_df_missing.columns), len(sample_df_missing.columns)))
for i in range(n_boot):
    recon_data = model.generate(n_samples=len(sample_df_missing)).numpy()
    reconstructed_df = pd.DataFrame(recon_data, columns=sample_df_missing.columns)
    resid = sample_df.corr() - reconstructed_df.corr()
    resid_array[i] = resid.values
    recons.append(reconstructed_df)

avg_resid = resid_array.mean(axis=0)
bootstrapped_resids = pd.DataFrame(avg_resid, columns=sample_df_missing.columns, index=sample_df_missing.columns)

plot_heatmap(bootstrapped_resids, title="""Expected Residuals \n Under Bootstrapped Reconstructions""")

recon_data = model.generate(n_samples=len(sample_df_missing)).numpy()

# Rebuild imputed DataFrame: keep observed values, fill masked cells with reconstructions
imputed_array = sample_df_missing.to_numpy().copy()
imputed_array[mask_remove] = recon_data[mask_remove]
imputed_df = pd.DataFrame(imputed_array, columns=sample_df_missing.columns)

print("\nSample imputed data:")
print(imputed_df.head())

Sample imputed data:
        JW1       JW2       JW3       UF1       UF2       FOR       DA1  \
0 -1.241607 -0.218270 -1.101984 -0.092699 -0.799012  0.174001 -0.370272   
1  0.635936 -0.614378  0.382529  1.100494  1.126426  1.081677  0.292383   
2 -0.551499 -2.280220 -1.304824 -0.506253 -0.047927 -0.113652  1.094697   
3  1.054921  1.173777 -0.138497 -0.186385 -1.286099 -0.618353  1.360622   
4 -0.886769  0.124020 -0.337509  0.162294  0.222886  0.051055 -0.122539   

        DA2       DA3       EBA        ST        MI  
0 -0.551770 -0.343902  1.277332  0.947278  0.495876  
1 -0.173557 -0.208669  1.073014  0.355770  0.256730  
2  0.889694  0.986469 -0.734002 -0.599533 -0.168585  
3  1.045441  1.339666 -0.369791  0.451579 -1.216746  
4  0.133379 -0.976451  1.344133 -0.079103  0.505976  
plot_heatmap(sample_df.corr() - imputed_df.corr())

fig, axs = plt.subplots(1,2 ,figsize=(10, 30))
axs = axs.flatten()
plot_heatmap(sample_df_missing.head(50).fillna(99), vmin=0, vmax=99, ax=axs[0], colorbar=False)
axs[0].set_title("Missing Data", fontsize=20)
plot_heatmap(imputed_df.head(50), vmin=-2, vmax=2, ax=axs[1], colorbar=False)
axs[1].set_title("Imputed Data", fontsize=20);
plt.tight_layout()

Bayesian Inference

def make_pymc_model(sample_df):
    coords = {'features': sample_df.columns,
              'features1': sample_df.columns,
              'obs': range(len(sample_df))}

    with pm.Model(coords=coords) as model:
        # Priors
        mus = pm.Normal("mus", 0, 1, dims='features')
        chol, _, _ = pm.LKJCholeskyCov("chol", n=12, eta=1.0, sd_dist=pm.HalfNormal.dist(1))
        cov = pm.Deterministic('cov', pm.math.dot(chol, chol.T), dims=('features', 'features1'))

        pm.MvNormal('likelihood', mus, cov=cov, observed=sample_df.values, dims=('obs', 'features'))
        
        idata = pm.sample_prior_predictive()
        idata.extend(pm.sample(random_seed=120))
        pm.sample_posterior_predictive(idata, extend_inferencedata=True)

    return idata, model 

idata, model = make_pymc_model(sample_df)
pm.model_to_graphviz(model)
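The `chol_corr` variable summarized below is, as I understand PyMC's `LKJCholeskyCov` with its default `compute_corr=True`, the correlation matrix implied by the sampled Cholesky factor. The relationship between a Cholesky factor, its covariance, and the implied correlation can be sketched in NumPy (random illustrative matrix, not the posterior):

```python
import numpy as np

rng = np.random.default_rng(1)

# A valid covariance built to be comfortably positive definite
A = rng.normal(size=(4, 4))
cov = A @ A.T + 4 * np.eye(4)

# The Cholesky factor L satisfies Sigma = L @ L.T
L = np.linalg.cholesky(cov)
assert np.allclose(L @ L.T, cov)

# Implied correlation: R = D^-1 Sigma D^-1, with D = diag of standard deviations
d = np.sqrt(np.diag(cov))
corr = cov / np.outer(d, d)

# A correlation matrix has unit diagonal, symmetry, and entries in [-1, 1]
assert np.allclose(np.diag(corr), 1.0)
assert np.allclose(corr, corr.T)
assert np.all(np.abs(corr) <= 1.0 + 1e-12)
```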

import arviz as az

expected_corr = pd.DataFrame(
    az.summary(idata, var_names=['chol_corr'])['mean'].values.reshape((12, 12)),
    columns=sample_df.columns, index=sample_df.columns
)

resids = sample_df.corr() - expected_corr
plot_heatmap(resids)

Missing Data

idata_missing, model_missing = make_pymc_model(sample_df_missing)
pm.model_to_graphviz(model_missing)

expected_corr = pd.DataFrame(
    az.summary(idata_missing, var_names=['chol_corr'])['mean'].values.reshape((12, 12)),
    columns=sample_df.columns, index=sample_df.columns
)

resids = sample_df.corr() - expected_corr
plot_heatmap(resids)

Citation

BibTeX citation:
@online{forde2025,
  author = {Forde, Nathaniel},
  title = {Amortized {Bayesian} {Inference} with {PyTorch}},
  date = {2025-07-25},
  langid = {en},
  abstract = {The cost of generating new sample data can be prohibitive.
    There is a secondary but different cost which attaches to the
    “construction” of novel data. Principal Components Analysis can be
    seen as a technique to optimally reconstruct a complex multivariate
    data set from a lower level compressed dimensional space.
    Variational auto-encoders allow us to achieve yet more flexible
    reconstruction results in non-linear cases. Drawing a new sample
    from the posterior predictive distribution of Bayesian models
    similarly supplies us with insight in the variability of realised
    data. Both methods assume a latent model of the data generating
    process that aims to leverage a compressed representation of the
    data. These are different heuristics with different consequences for
    how we understand the variability in the world. Amortized Bayesian
    inference seeks to unite the two heuristics.}
}
For attribution, please cite this work as:
Forde, Nathaniel. 2025. “Amortized Bayesian Inference with PyTorch.” July 25, 2025.